#!/usr/bin/env python
# -*- coding: utf-8 -*-


import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
import DonaldDuckDataset
import DonaldDuckModel
from DonaldDuckAttack import DonaldDuckAttack, adp_epsilons
import numpy as np
import DonaldDuckConv
# import eagerpy as ep
from Get_Detector import detectModels
import DonaldDuckFunc
import os
from DonaldDuckFunc import project, random_uniform, normal_gradient

class PGD(DonaldDuckAttack):
    """Projected Gradient Descent attack against ``self.model.model``.

    The clean batch, its labels and the model are provided by the
    ``DonaldDuckAttack`` base class (``self.images``, ``self.labels``,
    ``self.model``).
    """

    def create_adversarial_pattern(self, epsilon, steps, lp, targeted=False, random_start=True, stepsize=1):
        """Run ``steps`` PGD iterations on the categorical cross-entropy loss.

        Each iteration takes a step along the gradient normalised by
        ``normal_gradient`` (ascent when untargeted, descent when targeted),
        projects the perturbation back into the ``lp`` ball of radius
        ``epsilon`` around the clean images with ``project``, and clips the
        result to [0, 1].

        Args:
            epsilon: radius of the perturbation ball.
            steps: number of gradient iterations.
            lp: norm of the ball (the driver uses ``np.inf``, 2 and 1); it is
                forwarded to ``random_uniform``, ``normal_gradient`` and
                ``project``.
            targeted: if True, ``self.labels`` are treated as target labels
                and the loss is minimised instead of maximised.
            random_start: start from ``random_uniform(images, epsilon, lp)``
                instead of the clean images.
            stepsize: multiplier applied to the normalised gradient.

        Returns:
            ``(acc, acc_step)``: values of ``self.test_adv()`` (presumably the
            model accuracy on the current adversarial batch) sampled every
            ``max(5, steps // 6)`` iterations, and the loop indices at which
            they were sampled.

        Side effects:
            ``self.adv_examples`` is left holding the final adversarial batch
            as a NumPy array.
        """
        raw_data = tf.constant(self.images)
        loss_object = tf.keras.losses.CategoricalCrossentropy()
        # Untargeted: climb the loss; targeted: descend towards the target.
        direc = -1 if targeted else 1
        log_every = max(5, steps // 6)
        acc = []
        acc_step = []

        t_data = self.images
        if random_start:
            t_data = random_uniform(t_data, epsilon, lp)

        for idx in range(steps):
            t_data = tf.constant(t_data)

            with tf.GradientTape() as tape:
                # t_data is a plain tensor, so the tape must watch it explicitly.
                tape.watch(t_data)
                prediction = self.model.model(t_data)
                loss = loss_object(self.labels, prediction)
            gradient = tape.gradient(loss, t_data)
            gradient = normal_gradient(gradient, lp=lp)

            perturbation = t_data + gradient * stepsize * direc - raw_data
            perturbation = project(perturbation, epsilon, lp)
            t_data = tf.clip_by_value(perturbation + raw_data, 0, 1)

            if idx % log_every == 0:
                self.adv_examples = t_data.numpy()
                acc.append(self.test_adv())
                acc_step.append(idx)

        # Keep the final iterate. The periodic snapshot above can be several
        # iterations old (e.g. steps=10 last logs at idx 5), which would
        # silently shorten the attack for anyone reading self.adv_examples.
        self.adv_examples = tf.constant(t_data).numpy()
        return acc, acc_step

class Adapt_Attack1(DonaldDuckAttack):
    """Gradient attack on the classifier's cross-entropy loss only.

    Unlike ``Adapt_Attack2`` this class never consults a detector: the
    ``alpha`` and ``dgan`` arguments are accepted but not used.
    """

    def create_adversarial_pattern(self, epsilon, steps, lp, targeted=False, random_start=True, stepsize=1, alpha=100, dgan=None):
        """Run ``steps`` projected gradient iterations and store the result.

        Each iteration moves the batch by ``stepsize`` along the gradient
        normalised by ``normal_gradient`` (ascending the loss when
        untargeted, descending when targeted), projects the perturbation into
        the ``lp`` ball of radius ``epsilon`` and clips to [0, 1].

        The update is written out explicitly rather than via a Keras ``SGD``
        optimizer. The result is the same as momentum-free SGD, but applying a
        single optimizer to a freshly created ``tf.Variable`` on every step
        fails on newer Keras optimizers, which only update variables they
        were built with.

        Args:
            epsilon: radius of the perturbation ball.
            steps: number of gradient iterations.
            lp: norm of the ball, forwarded to the ``DonaldDuckFunc`` helpers.
            targeted: treat ``self.labels`` as target labels (minimise loss).
            random_start: start from ``random_uniform(images, epsilon, lp)``.
            stepsize: multiplier applied to the normalised gradient.
            alpha: unused.
            dgan: unused.

        Side effects:
            ``self.adv_examples`` holds the final adversarial batch as a
            NumPy array. Nothing is returned.
        """
        loss_object = tf.keras.losses.CategoricalCrossentropy()
        raw_data = tf.constant(self.images)
        # Untargeted: climb the loss; targeted: descend towards the target.
        direc = -1 if targeted else 1

        t_data = self.images
        if random_start:
            t_data = random_uniform(t_data, epsilon, lp)

        for _ in range(steps):
            t_data = tf.constant(t_data)

            with tf.GradientTape() as tape:
                tape.watch(t_data)
                prediction = self.model.model(t_data)
                loss = loss_object(self.labels, prediction)
            gradient = tape.gradient(loss, t_data)
            gradient = normal_gradient(gradient, lp=lp)

            t_data = t_data + gradient * stepsize * direc
            t_data = project(t_data - raw_data, epsilon, lp) + raw_data
            t_data = tf.clip_by_value(t_data, 0, 1)

        self.adv_examples = tf.constant(t_data).numpy()

class Adapt_Attack2(DonaldDuckAttack):
    """Adaptive attack against the classifier and the ``dgan`` detector.

    Each step combines the gradient that raises the classifier's
    cross-entropy with ``alpha`` times the gradient that lowers the MSE
    between the input and ``dgan``'s reconstruction of it.
    """

    def create_adversarial_pattern(self, dgan, epsilon, steps, lp, random_start=True, stepsize=1, alpha=100):
        """Run ``steps`` projected iterations of the combined objective.

        The reconstruction is ``dgan.decoder([dgan.encoder(x), dgan.target(x)])``.
        Each iteration steps ``x`` by ``stepsize`` against
        ``normal_gradient(alpha * grad(MSE) - grad(CE))``. It then projects the
        perturbation into the ``lp`` ball of radius ``epsilon`` and clips to
        [0, 1].

        The update is written out explicitly instead of through a Keras
        ``SGD`` optimizer. It computes the same momentum-free step, but avoids
        re-creating a ``tf.Variable`` each iteration, which newer Keras
        optimizers reject.

        Args:
            dgan: detector exposing ``target``, ``encoder`` and ``decoder``.
            epsilon: radius of the perturbation ball.
            steps: number of gradient iterations.
            lp: norm of the ball, forwarded to the ``DonaldDuckFunc`` helpers.
            random_start: start from ``random_uniform(images, epsilon, lp)``.
            stepsize: multiplier applied to the normalised gradient.
            alpha: weight of the reconstruction-error gradient.

        Returns:
            ``(acc, acc_step)``: values of ``self.test_adv()`` sampled every
            ``max(5, steps // 6)`` iterations, and the loop indices at which
            they were sampled.

        Side effects:
            ``self.adv_examples`` is left holding the final adversarial batch
            as a NumPy array.
        """
        loss_CE = tf.keras.losses.CategoricalCrossentropy()
        loss_MSE = tf.keras.losses.MeanSquaredError()
        raw_data = tf.constant(self.images)
        log_every = max(5, steps // 6)
        acc = []
        acc_step = []

        t_data = self.images
        if random_start:
            t_data = random_uniform(t_data, epsilon, lp)

        for idx in range(steps):
            t_data = tf.constant(t_data)

            # Gradient of the classification loss (to be increased).
            with tf.GradientTape() as tape:
                tape.watch(t_data)
                prediction = self.model.model(t_data)
                loss1 = loss_CE(self.labels, prediction)
            gradient1 = tape.gradient(loss1, t_data)

            # Gradient of the detector's reconstruction error (to be decreased).
            with tf.GradientTape() as tape:
                tape.watch(t_data)
                logits = dgan.target(t_data)
                semantics = dgan.encoder(t_data)
                r_imgs = dgan.decoder([semantics, logits])
                loss2 = loss_MSE(t_data, r_imgs)
            gradient2 = tape.gradient(loss2, t_data)

            # Descent direction: lower the weighted reconstruction error while
            # raising the classification loss.
            gradient = normal_gradient(gradient2 * alpha - gradient1, lp=lp)
            t_data = t_data - gradient * stepsize
            t_data = project(t_data - raw_data, epsilon, lp) + raw_data
            t_data = tf.clip_by_value(t_data, 0, 1)

            if idx % log_every == 0:
                self.adv_examples = t_data.numpy()
                acc.append(self.test_adv())
                acc_step.append(idx)

        # Keep the final iterate rather than the last (possibly stale) snapshot.
        self.adv_examples = tf.constant(t_data).numpy()
        return acc, acc_step
        

if __name__ == '__main__':
    # Driver: for each norm / epsilon, run plain PGD and Adapt_Attack2 (for
    # several alphas). Plot mean accuracy versus attack step, and box-plot the
    # per-image scores returned by dgan.detect_adv for clean vs. adversarial
    # inputs.

    # Use physical GPU 1 (PCI bus order) and let TensorFlow grow GPU memory
    # on demand rather than reserving the whole card.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    for gpu in tf.config.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            # Raised if the GPU was already initialised; memory growth is
            # only an optimisation, so report it and carry on.
            print('Could not enable memory growth on ' + gpu.name + ': ' + str(err))

    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    # and 3.8 removed the old names; use whichever this install provides.
    plot_style = 'seaborn-v0_8-whitegrid'
    if plot_style not in plt.style.available:
        plot_style = 'seaborn-whitegrid'
    plt.style.use(plot_style)

    dataset=DonaldDuckDataset.CIFAR10(standardization=False)

    # Target classifier: VGG16 for CIFAR-10, otherwise a 5-layer CNN; weights
    # are loaded from savedModels/.
    if dataset.name=='cifar10':
        tar_model = DonaldDuckConv.DonaldDuckVGG16(
            dataset,
            build_dir=False
        )
        tar_model.setModel()
        tar_model.load_model(
            weights_path=r'savedModels//' + 'VGG16' +
                        '_' + dataset.name+ '.h5'
        )
    else:
        conv_layers_num = 5
        init_filters = 32
        tar_model = DonaldDuckConv.DonaldDuckCNN(
            dataset,
            build_dir=False
        )
        tar_model.setModel(
            conv_layers_num=conv_layers_num,
            filters=init_filters,
            kernel_size=(3,3)
        )
        tar_model.load_model(
            weights_path=r'savedModels//'+'CNN'+'_'+dataset.name
                                    +'_'+str(conv_layers_num)
                                    +'_'+str(init_filters)+'.h5'
        )

    # Detector ('DRR' entry of detectModels) built around the target model;
    # its encoder/decoder/discriminator weights come from savedModels/<dir>/.
    dms='DRR'
    dgan=detectModels[dms]['model'](
        dataset,
        batch_size=128,
        epochs=50,
        kernel_size=(3,3),
        # build_dir=False
    )
    dgan.setModel(
        tar_model=tar_model,
        skip_flag=False,
    )
    weight_date=detectModels[dms][dataset.name]['dir']
    weight_idx=str(detectModels[dms][dataset.name]['idx'])
    model_name=detectModels[dms][dataset.name]['name']
    dgan.loadWeights(
            encoder_weight_path='savedModels//'+weight_date+
                                '//weight_encoder_'+model_name+
                                weight_idx+'.h5',
            decoder_weight_path='savedModels//'+weight_date+
                                '//weight_decoder_'+model_name+
                                weight_idx+'.h5',
            disI_weight_path='savedModels//'+weight_date+
                            '//weight_dis_'+model_name+
                            weight_idx+'.h5',
    )

    random_start = True
    rounds = 25
    advNum = 400
    stepsize = 0.1
    ext = '.pdf'
    alphas = [1e2, 1e3, 1e4, 1e5]

    for lp in [np.inf, 2, 1]:
        # Iteration budget per norm: 150 steps for L1, 10 otherwise. Set per
        # iteration so it does not depend on the order of the norm list.
        steps = 150 if lp == 1 else 10
        # Norm tag used in detector image names: 'Linf', 'L2' or 'L1'.
        lp_name = 'Linf' if lp == np.inf else 'L' + str(lp)
        epsilons = adp_epsilons[dataset.name][lp]
        for epsilon in epsilons:
            # Detector scores per attack: entry 0 is PGD, then one per alpha.
            dl_raw, dl_adv = [], []

            # Baseline: plain PGD over `rounds` batches of `advNum` images.
            pgd_acc = []
            adv_exps = np.zeros((advNum*rounds,) + dataset.input_shape)
            raw_exps = np.zeros((advNum*rounds,) + dataset.input_shape)
            for idx in range(rounds):
                fa = PGD(
                    model=tar_model,
                    advNum=advNum
                )
                accs, pgd_steps = fa.create_adversarial_pattern(
                    epsilon=epsilon,
                    steps=steps,
                    lp=lp,
                    targeted=False,
                    random_start=random_start,
                    stepsize=stepsize
                )
                pgd_acc.append(accs)
                adv_exps[advNum*idx:advNum*(idx+1)] = fa.adv_examples
                raw_exps[advNum*idx:advNum*(idx+1)] = fa.images

            dgan.testAdv = adv_exps
            dgan.testClean = raw_exps
            pgd_dis_raws, pgd_dis_advs, _, _ = dgan.detect_adv(
                img_name='PGD' + '_' + lp_name + '_' + str(epsilon)+'_'+DonaldDuckFunc.getTimeStamp(),
                plot_flag=False
            )
            dl_raw.append(pgd_dis_raws)
            dl_adv.append(pgd_dis_advs)

            # Mean accuracy trace over rounds.
            pgd_accs = np.mean(np.array(pgd_acc), axis=0)
            # Accuracy logged at loop index idx was measured after idx + 1
            # updates, so shift every x position by one.
            pgd_steps = [s + 1 for s in pgd_steps]

            # Adaptive attack, one full set of rounds per reconstruction weight.
            adp_accs = []
            for alpha in alphas:
                print(str(alpha)+' '+str(epsilon))
                adp_acc = []

                adv_exps = np.zeros((advNum*rounds,) + dataset.input_shape)
                raw_exps = np.zeros((advNum*rounds,) + dataset.input_shape)
                for idx in range(rounds):
                    fa = Adapt_Attack2(
                        model=tar_model,
                        advNum=advNum
                    )
                    accs, adp_steps = fa.create_adversarial_pattern(
                        alpha=alpha,
                        dgan=dgan,
                        epsilon=epsilon,
                        steps=steps,
                        lp=lp,
                        random_start=random_start,
                        stepsize=1
                    )
                    adp_acc.append(accs)
                    adv_exps[advNum*idx:advNum*(idx+1)] = fa.adv_examples
                    raw_exps[advNum*idx:advNum*(idx+1)] = fa.images

                dgan.testAdv = adv_exps
                dgan.testClean = raw_exps
                adp_dis_raws, adp_dis_advs, _, _ = dgan.detect_adv(
                    img_name='Adapt' + '_' + lp_name + '_' + str(epsilon)+'_'+DonaldDuckFunc.getTimeStamp(),
                    plot_flag=False
                )
                dl_raw.append(adp_dis_raws)
                dl_adv.append(adp_dis_advs)

                adp_accs.append(np.mean(np.array(adp_acc), axis=0))
                # Same one-step shift as for the PGD trace.
                adp_steps = [s + 1 for s in adp_steps]

            # Accuracy-vs-step curves for every alpha plus the PGD baseline.
            for idx, alpha in enumerate(alphas):
                plt.plot(adp_steps, adp_accs[idx], marker=(4+idx),label='adp_'+str(10)+'^'+str(int(np.log10(alpha))))

            plt.plot(pgd_steps, pgd_accs, marker=(5+len(alphas)),label='PGD')
            plt.ylabel('accuracy', fontsize=22)
            plt.xlabel('steps', fontsize=22)
            plt.xticks(fontsize=22)
            plt.yticks(fontsize=22)
            plt.legend(fontsize=22)
            plt.tight_layout()
            plt.savefig(dgan.saveImgPath + '/' +dataset.name+ '_acc_'+ str(lp)+'_'+ str(epsilon)+ ext)
            plt.clf()

            # Overlaid box plots of detector scores: clean in blue, adversarial in red.
            bp_labels=['PGD']+['adp\n'+str(10)+'^'+str(int(np.log10(a))) for a in alphas]
            bp_raw=plt.boxplot(
                dl_raw,
                patch_artist=True,
                showfliers=False,
                labels=bp_labels
            )
            for box in bp_raw['boxes']:
                box.set(color='blue', linewidth=2)
                box.set(facecolor = 'blue' )
                box.set(alpha = 0.3)

            bp_adv=plt.boxplot(
                dl_adv,
                patch_artist=True,
                showfliers=False,
                labels=bp_labels,
            )
            for box in bp_adv['boxes']:
                box.set(color='red', linewidth=2)
                box.set(facecolor = 'red' )
                box.set(alpha = 0.3)
            plt.ylabel('Reconstruction Error',fontsize=22)
            plt.xticks(fontsize=22)
            plt.yticks(fontsize=22)
            plt.savefig(dgan.saveImgPath + '/' +dataset.name+ '_box_'+ str(lp)+'_'+ str(epsilon)+ ext)
            plt.clf()